19676a
@@ -82,54 +83,34 @@
public Object process(Node nd, Stack<Node> stack, NodeProcessorCtx procCtx,
     }
   }
 
-  public class FixedBucketPartitionWalker extends FilterPruner {
-
-    @Override
-    protected void generatePredicate(NodeProcessorCtx procCtx,
-        FilterOperator fop, TableScanOperator top) throws SemanticException {
-      FixedBucketPruningOptimizerCtxt ctxt = ((FixedBucketPruningOptimizerCtxt) procCtx);
-      Table tbl = top.getConf().getTableMetadata();
-      if (tbl.getNumBuckets() > 0) {
-        final int nbuckets = tbl.getNumBuckets();
-        ctxt.setNumBuckets(nbuckets);
-        ctxt.setBucketCols(tbl.getBucketCols());
-        ctxt.setSchema(tbl.getFields());
-        if (tbl.isPartitioned()) {
-          // Run partition pruner to get partitions
-          ParseContext parseCtx = ctxt.pctx;
-          PrunedPartitionList prunedPartList;
-          try {
-            String alias = (String) parseCtx.getTopOps().keySet().toArray()[0];
-            prunedPartList = PartitionPruner.prune(top, parseCtx, alias);
-          } catch (HiveException e) {
-            throw new SemanticException(e.getMessage(), e);
-          }
-          if (prunedPartList != null) {
-            ctxt.setPartitions(prunedPartList);
-            for (Partition p : prunedPartList.getPartitions()) {
-              if (nbuckets != p.getBucketCount()) {
-                // disable feature
-                ctxt.setNumBuckets(-1);
-                break;
-              }
-            }
-          }
-        }
-      }
-    }
-  }
-
   public static class BucketBitsetGenerator extends FilterPruner {
 
     @Override
     protected void generatePredicate(NodeProcessorCtx procCtx,
-        FilterOperator fop, TableScanOperator top) throws SemanticException {
+        FilterOperator fop, TableScanOperator top) throws SemanticException {
       FixedBucketPruningOptimizerCtxt ctxt = ((FixedBucketPruningOptimizerCtxt) procCtx);
-      if (ctxt.getNumBuckets() <= 0 || ctxt.getBucketCols().size() != 1) {
+      Table tbl = top.getConf().getTableMetadata();
+      int numBuckets = tbl.getNumBuckets();
+      if (numBuckets <= 0 || tbl.getBucketCols().size() != 1) {
         // bucketing isn't consistent or there are >1 bucket columns
         // optimizer does not extract multiple column predicates for this
         return;
       }
+
+      if (tbl.isPartitioned()) {
+        // Make sure all the partitions have same bucket count.
+        PrunedPartitionList prunedPartList =
+            PartitionPruner.prune(top, ctxt.pctx, top.getConf().getAlias());
+        if (prunedPartList != null) {
+          for (Partition p : prunedPartList.getPartitions()) {
+            if (numBuckets != p.getBucketCount()) {
+              // disable feature
+              return;
+            }
+          }
+        }
+      }
+
       ExprNodeGenericFuncDesc filter = top.getConf().getFilterExpr();
       if (filter == null) {
         return;
@@ -139,9 +120,9 @@
protected void generatePredicate(NodeProcessorCtx procCtx,
       if (sarg == null) {
         return;
       }
-      final String bucketCol = ctxt.getBucketCols().get(0);
+      final String bucketCol = tbl.getBucketCols().get(0);
       StructField bucketField = null;
-      for (StructField fs : ctxt.getSchema()) {
+      for (StructField fs : tbl.getFields()) {
         if(fs.getFieldName().equals(bucketCol)) {
           bucketField = fs;
         }
@@ -221,7 +202,7 @@
protected void generatePredicate(NodeProcessorCtx procCtx,
         }
       }
       // invariant: bucket-col IN literals of type bucketField
-      BitSet bs = new BitSet(ctxt.getNumBuckets());
+      BitSet bs = new BitSet(numBuckets);
       bs.clear();
       PrimitiveObjectInspector bucketOI = (PrimitiveObjectInspector)bucketField.getFieldObjectInspector();
       PrimitiveObjectInspector constOI = PrimitiveObjectInspectorFactory.getPrimitiveWritableObjectInspector(bucketOI.getPrimitiveCategory());
@@ -237,22 +218,22 @@
protected void generatePredicate(NodeProcessorCtx procCtx,
         }
         Object convCols[] = new Object[] {conv.convert(literal)};
         int n = bucketingVersion == 2 ?
-            ObjectInspectorUtils.getBucketNumber(convCols, new ObjectInspector[]{constOI}, ctxt.getNumBuckets()) :
-            ObjectInspectorUtils.getBucketNumberOld(convCols, new ObjectInspector[]{constOI}, ctxt.getNumBuckets());
+            ObjectInspectorUtils.getBucketNumber(convCols, new ObjectInspector[]{constOI}, numBuckets) :
+            ObjectInspectorUtils.getBucketNumberOld(convCols, new ObjectInspector[]{constOI}, numBuckets);
         bs.set(n);
         if (bucketingVersion == 1 && ctxt.isCompat()) {
           int h = ObjectInspectorUtils.getBucketHashCodeOld(convCols, new ObjectInspector[]{constOI});
           // -ve hashcodes had conversion to positive done in different ways in the past
           // abs() is now obsolete and all inserts now use & Integer.MAX_VALUE 
           // the compat mode assumes that old data could've been loaded using the other conversion
-          n = ObjectInspectorUtils.getBucketNumber(Math.abs(h), ctxt.getNumBuckets());
+          n = ObjectInspectorUtils.getBucketNumber(Math.abs(h), numBuckets);
           bs.set(n);
         }
       }
-      if (bs.cardinality() < ctxt.getNumBuckets()) {
+      if (bs.cardinality() < numBuckets) {
         // there is a valid bucket pruning filter
         top.getConf().setIncludedBuckets(bs);
-        top.getConf().setNumBuckets(ctxt.getNumBuckets());
+        top.getConf().setNumBuckets(numBuckets);
       }
     }
 
@@ -339,19 +320,9 @@
public ParseContext transform(ParseContext pctx) throws SemanticException {
     FixedBucketPruningOptimizerCtxt opPartWalkerCtx = new FixedBucketPruningOptimizerCtxt(compat,
         pctx);
 
-    // Retrieve all partitions generated from partition pruner and partition
-    // column pruner
+    // walk operator tree to create expression tree for filter buckets
     PrunerUtils.walkOperatorTree(pctx, opPartWalkerCtx,
-        new FixedBucketPartitionWalker(), new NoopWalker());
-
-    if (opPartWalkerCtx.getNumBuckets() < 0) {
-      // bail out
-      return pctx;
-    } else {
-      // walk operator tree to create expression tree for filter buckets
-      PrunerUtils.walkOperatorTree(pctx, opPartWalkerCtx,
-          new BucketBitsetGenerator(), new NoopWalker());
-    }
+        new BucketBitsetGenerator(), new NoopWalker());
 
     return pctx;
   }
